﻿// Accord Imaging Library
// The Accord.NET Framework
// http://accord-framework.net
//
// Copyright © César Souza, 2009-2017
// cesarsouza at gmail.com
//
//    This library is free software; you can redistribute it and/or
//    modify it under the terms of the GNU Lesser General Public
//    License as published by the Free Software Foundation; either
//    version 2.1 of the License, or (at your option) any later version.
//
//    This library is distributed in the hope that it will be useful,
//    but WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
//    Lesser General Public License for more details.
//
//    You should have received a copy of the GNU Lesser General Public
//    License along with this library; if not, write to the Free Software
//    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
//

namespace Accord.MachineLearning
{
    using Accord.Math;
    using System;
    using System.Collections.Generic;
    using System.Threading;
    using System.Linq;
    using Accord.Compat;
    using System.Threading.Tasks;
    using Accord.Statistics.Distributions.Univariate;
    using System.Diagnostics;
    using Accord.Statistics.Distributions.Fitting;

    /// <summary>
    ///   Base class for <see cref="BagOfWords">Bag of Audiovisual Words</see> implementations.
    /// </summary>
    /// 
    [Serializable]
    public class BaseBagOfWords<TModel, TPoint, TFeature, TClustering, TExtractor, TInput> :
        ParallelLearningBase, ITransform<TInput, int[]>, ITransform<TInput, double[]>
        where TPoint : IFeatureDescriptor<TFeature>
        where TModel : BaseBagOfWords<TModel, TPoint, TFeature, TClustering, TExtractor, TInput>
        where TClustering : IUnsupervisedLearning<IClassifier<TFeature, int>, TFeature, int>
        where TExtractor : IFeatureExtractor<TPoint, TInput>, ICloneable
    {

        // Classifier returned by the clustering algorithm. It is assigned only by the
        // private learn method and is used by the Transform overloads to map a feature
        // descriptor to the index of its codeword.
        private IClassifier<TFeature, int> classifier;

        /// <summary>
        ///   Gets the number of words in this codebook. The value is taken from the
        ///   number of classes of the classifier produced by the last learning call.
        /// </summary>
        /// 
        public int NumberOfWords { get; private set; }

        /// <summary>
        ///   Gets or sets the maximum number of descriptors that should be used 
        ///   to learn the codebook. Default is 0 (meaning to use all descriptors).
        /// </summary>
        /// 
        /// <remarks>
        ///   This setting is only consulted when learning from input instances; it is
        ///   not read when learning directly from an array of feature descriptors.
        ///   When greater than zero, learning from instances throws an
        ///   <see cref="InvalidOperationException"/> if fewer descriptors than this
        ///   value could be collected.
        /// </remarks>
        /// 
        /// <value>The maximum number of descriptors used for learning.</value>
        /// 
        public int NumberOfDescriptors { get; set; }

        /// <summary>
        ///   Gets or sets the maximum number of descriptors per instance (e.g. per image)
        ///   that should be used to learn the codebook. Default is 0 (meaning to use all
        ///   descriptors extracted from each instance).
        /// </summary>
        /// 
        /// <remarks>
        ///   This setting is only consulted when learning from input instances.
        /// </remarks>
        /// 
        /// <value>The maximum number of descriptors taken from each instance.</value>
        /// 
        public int MaxDescriptorsPerInstance { get; set; }

        /// <summary>
        ///   Gets the clustering algorithm used to create this model.
        /// </summary>
        /// 
        public TClustering Clustering { get; private set; }

        /// <summary>
        ///   Gets the feature extractor used to identify features in the input data.
        /// </summary>
        /// 
        public TExtractor Detector { get; private set; }

        /// <summary>
        ///   Gets the number of inputs accepted by the model. This implementation
        ///   always reports -1.
        /// </summary>
        /// <value>The number of inputs (always -1).</value>
        public int NumberOfInputs
        {
            get { return -1; }
        }

        /// <summary>
        ///   Gets the number of outputs generated by the model, which is
        ///   the same value as <see cref="NumberOfWords"/>.
        /// </summary>
        /// <value>The number of outputs.</value>
        public int NumberOfOutputs
        {
            get { return NumberOfWords; }
        }

        /// <summary>
        ///   Gets statistics about the last codebook learned. A new statistics
        ///   object is assigned each time one of the Learn methods runs.
        /// </summary>
        /// 
        public BagOfWordsStatistics Statistics { get; private set; }

        /// <summary>
        ///   Constructs a new <see cref="BaseBagOfWords{TModel, TPoint, TFeature, TClustering, TExtractor, TInput}"/>.
        /// </summary>
        /// 
        /// <remarks>
        ///   This constructor performs no initialization. The feature extractor and the
        ///   clustering algorithm are assigned by <c>Init</c>, not here.
        /// </remarks>
        /// 
        protected BaseBagOfWords()
        {
        }

        /// <summary>
        ///   Stores the feature extractor and the clustering algorithm used by this model.
        /// </summary>
        /// 
        /// <param name="detector">The feature extractor to be exposed through <see cref="Detector"/>.</param>
        /// <param name="algorithm">The clustering algorithm to be exposed through <see cref="Clustering"/>.
        ///   When it also implements <see cref="IParallel"/>, its parallel options
        ///   object is adopted by this instance.</param>
        /// 
        protected void Init(TExtractor detector, TClustering algorithm)
        {
            Detector = detector;
            Clustering = algorithm;

            // Algorithms without parallelization settings leave our own options untouched
            var parallelAlgorithm = algorithm as IParallel;
            if (parallelAlgorithm == null)
                return;

            ParallelOptions = parallelAlgorithm.ParallelOptions;
        }



        #region Transform

        /// <summary>
        ///   Counts, for a set of already extracted feature points, how many
        ///   of them are assigned to each codeword by the learned classifier.
        /// </summary>
        /// <param name="input">The feature points whose descriptors should be classified.</param>
        /// <param name="result">The array that receives the counts. It is indexed by the
        ///   class label returned by the classifier, and each occurrence is added on top
        ///   of the value already stored (the array is not cleared first).</param>
        /// <returns>The same <paramref name="result"/> array that was passed in.</returns>
        public int[] Transform(IEnumerable<TPoint> input, int[] result)
        {
            // Materialize the sequence so that it can be indexed from the parallel loop
            List<TPoint> points = input.ToList();
            int count = points.Count;

            Parallel.For(0, count, ParallelOptions, index =>
            {
                // Iterations may run concurrently, hence the atomic increment
                int word = classifier.Decide(points[index].Descriptor);
                Interlocked.Increment(ref result[word]);
            });

            return result;
        }

        /// <summary>
        ///   Counts, for a set of already extracted feature points, how many
        ///   of them are assigned to each codeword by the learned classifier,
        ///   storing the counts as double-precision values.
        /// </summary>
        /// <param name="input">The feature points whose descriptors should be classified.</param>
        /// <param name="result">The array that receives the counts. It is indexed by the
        ///   class label returned by the classifier, and each occurrence is added on top
        ///   of the value already stored (the array is not cleared first).</param>
        /// <returns>The same <paramref name="result"/> array that was passed in.</returns>
        public double[] Transform(IEnumerable<TPoint> input, double[] result)
        {
            // Materialize the sequence so that it can be indexed from the parallel loop
            List<TPoint> points = input.ToList();
            int count = points.Count;

            Parallel.For(0, count, ParallelOptions, index =>
            {
                // Iterations may run concurrently, so the increment goes through InterlockedEx
                int word = classifier.Decide(points[index].Descriptor);
                InterlockedEx.Increment(ref result[word]);
            });

            return result;
        }

        /// <summary>
        ///   Extracts feature points from a single input using <see cref="Detector"/>
        ///   and accumulates their codeword counts into the given array.
        /// </summary>
        /// <param name="input">The input from which features should be extracted.</param>
        /// <param name="result">The array where the codeword counts are accumulated.</param>
        /// <returns>The same <paramref name="result"/> array that was passed in.</returns>
        public double[] Transform(TInput input, double[] result)
        {
            IEnumerable<TPoint> points = Detector.Transform(input);
            return Transform(points, result);
        }

        /// <summary>
        ///   Extracts feature points from a single input using <see cref="Detector"/>
        ///   and accumulates their integer codeword counts into the given array.
        /// </summary>
        /// <param name="input">The input from which features should be extracted.</param>
        /// <param name="result">The array where the codeword counts are accumulated.</param>
        /// <returns>The same <paramref name="result"/> array that was passed in.</returns>
        public int[] Transform(TInput input, int[] result)
        {
            IEnumerable<TPoint> points = Detector.Transform(input);
            return Transform(points, result);
        }


        /// <summary>
        ///   Computes the codeword counts for a list of already extracted feature points.
        /// </summary>
        /// <param name="input">The feature points whose descriptors should be classified.</param>
        /// <returns>A newly allocated array with <see cref="NumberOfWords"/> entries
        ///   holding how many points were assigned to each codeword.</returns>
        public double[] Transform(List<TPoint> input)
        {
            double[] histogram = new double[NumberOfWords];
            return Transform(input, histogram);
        }

        /// <summary>
        ///   Computes the codeword counts for a single input.
        /// </summary>
        /// <param name="input">The input from which features should be extracted.</param>
        /// <returns>A newly allocated array with <see cref="NumberOfWords"/> entries
        ///   holding how many extracted features were assigned to each codeword.</returns>
        public double[] Transform(TInput input)
        {
            double[] histogram = new double[NumberOfWords];
            return Transform(input, histogram);
        }

        // Integer-valued variant of Transform(TInput): allocates a fresh count
        // array with NumberOfWords entries and fills it for the given input.
        int[] ICovariantTransform<TInput, int[]>.Transform(TInput input)
        {
            int[] histogram = new int[NumberOfWords];
            return Transform(input, histogram);
        }

        /// <summary>
        ///   Computes the codeword counts for each input in a set of inputs.
        /// </summary>
        /// <param name="input">The inputs from which features should be extracted.</param>
        /// <returns>A newly allocated jagged array with one row per input and
        ///   <see cref="NumberOfWords"/> columns per row.</returns>
        public double[][] Transform(TInput[] input)
        {
            double[][] histograms = Jagged.Zeros(input.Length, NumberOfWords);
            return Transform(input, histograms);
        }

        // Integer-valued variant of Transform(TInput[]): allocates one count row
        // with NumberOfWords entries per input and fills them.
        int[][] ICovariantTransform<TInput, int[]>.Transform(TInput[] input)
        {
            int[][] histograms = Jagged.Zeros<int>(input.Length, NumberOfWords);
            return Transform(input, histograms);
        }




        /// <summary>
        ///   Runs <paramref name="action"/> for every index in the given range, handing
        ///   it a feature detector that is safe to use from the executing thread.
        /// </summary>
        /// 
        /// <remarks>
        ///   When the maximum degree of parallelism is 1, a plain sequential loop is run
        ///   using <see cref="Detector"/> directly. Otherwise the range is processed with
        ///   <c>Parallel.For</c>, where each worker operates on its own clone of the
        ///   detector; clones that implement <see cref="IDisposable"/> are disposed when
        ///   the worker finishes.
        /// </remarks>
        /// 
        /// <param name="fromInclusive">The first index (inclusive).</param>
        /// <param name="toExclusive">The last index (exclusive).</param>
        /// <param name="action">The work to run for each index, receiving the index
        ///   and the detector to be used for that index.</param>
        /// 
        protected void For(int fromInclusive, int toExclusive, Action<int, TExtractor> action)
        {
            if (ParallelOptions.MaxDegreeOfParallelism != 1)
            {
                // Workers must not share one detector instance, so each one starts from a clone
                Func<TExtractor> createLocalDetector = () => (TExtractor)Detector.Clone();

                Func<int, ParallelLoopState, TExtractor, TExtractor> runIteration =
                    (index, loopState, localDetector) =>
                    {
                        action(index, localDetector);
                        return localDetector; // keep using the same clone on this worker
                    };

                Action<TExtractor> releaseLocalDetector = localDetector =>
                {
                    var disposable = localDetector as IDisposable;
                    if (disposable != null)
                        disposable.Dispose();
                };

                Parallel.For<TExtractor>(fromInclusive, toExclusive, ParallelOptions,
                    createLocalDetector, runIteration, releaseLocalDetector);
                return;
            }

            // Single-threaded execution: the shared detector can be used as-is
            int current = fromInclusive;
            while (current < toExclusive)
            {
                action(current, Detector);
                current++;
            }
        }

        /// <summary>
        ///   Computes the codeword counts for each input in a set of inputs,
        ///   accumulating them into the rows of an existing jagged array.
        /// </summary>
        /// <param name="input">The inputs from which features should be extracted.</param>
        /// <param name="result">The jagged array whose i-th row receives the
        ///   counts for the i-th input.</param>
        /// <returns>The same <paramref name="result"/> array that was passed in.</returns>
        public double[][] Transform(TInput[] input, double[][] result)
        {
            Action<int, TExtractor> processInstance = (index, extractor) =>
            {
                // Use the detector supplied by For rather than the shared one
                IEnumerable<TPoint> points = extractor.Transform(input[index]);
                Transform(points, result[index]);
            };

            For(0, input.Length, processInstance);

            return result;
        }

        /// <summary>
        ///   Computes the integer codeword counts for each input in a set of inputs,
        ///   accumulating them into the rows of an existing jagged array.
        /// </summary>
        /// <param name="input">The inputs from which features should be extracted.</param>
        /// <param name="result">The jagged array whose i-th row receives the
        ///   counts for the i-th input.</param>
        /// <returns>The same <paramref name="result"/> array that was passed in.</returns>
        public int[][] Transform(TInput[] input, int[][] result)
        {
            Action<int, TExtractor> processInstance = (index, extractor) =>
            {
                // Use the detector supplied by For rather than the shared one
                IEnumerable<TPoint> points = extractor.Transform(input[index]);
                Transform(points, result[index]);
            };

            For(0, input.Length, processInstance);

            return result;
        }

        #endregion



        #region Learn
        /// <summary>
        ///   Learns the codebook directly from an array of feature descriptors.
        /// </summary>
        /// <param name="x">The feature descriptors to be clustered.</param>
        /// <param name="weights">The weight of importance for each descriptor, or null.</param>
        /// <returns>This instance, after the codebook has been learned from
        ///   <paramref name="x" />.</returns>
        /// <exception cref="DimensionMismatchException">
        ///   <paramref name="weights"/> is given but its length differs from that of <paramref name="x"/>.
        /// </exception>
        /// <exception cref="InvalidOperationException">
        ///   The number of descriptors does not exceed the current <see cref="NumberOfWords"/>.
        /// </exception>
        public TModel Learn(TFeature[] x, double[] weights = null)
        {
            bool weightsMismatch = weights != null && x.Length != weights.Length;
            if (weightsMismatch)
                throw new DimensionMismatchException("weights", "The weights vector should have the same length as x.");

            if (NumberOfWords >= x.Length)
            {
                throw new InvalidOperationException("Not enough data points to cluster. Please try "
                    + "to adjust the feature extraction algorithm to generate more points.");
            }

            // Only the descriptor total is known when learning from raw descriptors
            var statistics = new BagOfWordsStatistics();
            statistics.TotalNumberOfDescriptors = x.Length;
            this.Statistics = statistics;

            return learn(x, weights);
        }


        /// <summary>
        ///   Learns the codebook from a set of input instances, extracting their
        ///   feature descriptors with the feature detector.
        /// </summary>
        /// <param name="x">The input instances.</param>
        /// <param name="weights">The weight of importance for each input instance, or null.</param>
        /// <returns>This instance, after the codebook has been learned from
        ///   <paramref name="x" />.</returns>
        public TModel Learn(TInput[] x, double[] weights = null)
        {
            // Feature extraction is delegated to whichever detector InnerLearn hands over
            Func<TInput, TExtractor, IEnumerable<TPoint>> extract = (instance, detector) =>
            {
                return detector.Transform(instance);
            };

            return InnerLearn<TInput>(x, weights, extract);
        }

        /// <summary>
        ///   Generic learn method implementation that should work for any input type.
        ///   This method is useful for re-using code between methods that accept Bitmap,
        ///   BitmapData, UnmanagedImage, filenames as strings, etc.
        /// </summary>
        /// 
        /// <remarks>
        ///   Descriptors are extracted from each instance (optionally capped per instance by
        ///   <see cref="MaxDescriptorsPerInstance"/>), pooled together, optionally subsampled
        ///   down to <see cref="NumberOfDescriptors"/>, and then passed to the clustering
        ///   algorithm. Each pooled descriptor carries the weight of the instance it came from.
        ///   <see cref="Statistics"/> is replaced with figures about this run.
        /// </remarks>
        /// 
        /// <typeparam name="T">The input type.</typeparam>
        /// 
        /// <param name="x">The inputs.</param>
        /// <param name="weights">The weight of each input instance, or null for no weights.</param>
        /// <param name="extractor">A function that knows how to process the input 
        ///   and extract features from them.</param>
        /// 
        /// <returns>The trained model.</returns>
        /// 
        /// <exception cref="DimensionMismatchException">
        ///   <paramref name="weights"/> is given but its length differs from that of <paramref name="x"/>.
        /// </exception>
        /// <exception cref="InvalidOperationException">
        ///   <see cref="NumberOfDescriptors"/> is positive and fewer descriptors than that were collected.
        /// </exception>
        /// 
        protected TModel InnerLearn<T>(T[] x, double[] weights,
            Func<T, TExtractor, IEnumerable<TPoint>> extractor)
        {
            // Same validation as Learn(TFeature[], double[]): one weight per instance
            if (weights != null && weights.Length != x.Length)
                throw new DimensionMismatchException("weights", "The weights vector should have the same length as x.");

            var descriptorsPerInstance = new TFeature[x.Length][];
            var totalDescriptorCounts = new double[x.Length];
            int takenDescriptorCount = 0;

            // For all instances
            For(0, x.Length, (i, detector) =>
            {
                // Skip the remaining instances once enough descriptors have been gathered.
                // Skipped instances keep a null entry in descriptorsPerInstance.
                if (NumberOfDescriptors > 0 && takenDescriptorCount >= NumberOfDescriptors)
                    return;

                TFeature[] desc = extractor(x[i], detector).Select(p => p.Descriptor).ToArray();

                totalDescriptorCounts[i] = desc.Length;

                // The per-instance setting is a maximum: only subsample instances that
                // actually exceed it, instead of requesting more samples than exist.
                if (MaxDescriptorsPerInstance > 0 && desc.Length > MaxDescriptorsPerInstance)
                    desc = desc.Sample(MaxDescriptorsPerInstance);

                Interlocked.Add(ref takenDescriptorCount, desc.Length);

                descriptorsPerInstance[i] = desc;
            });

            if (NumberOfDescriptors > 0 && takenDescriptorCount < NumberOfDescriptors)
            {
                throw new InvalidOperationException("There were not enough descriptors to sample the desired amount " +
                    "of samples ({0}). Please either increase the number of images, or increase the number of ".Format(NumberOfDescriptors) +
                    "descriptors that are sampled from each image by adjusting the MaxDescriptorsPerInstance property ({0}).".Format(MaxDescriptorsPerInstance));
            }

            var totalDescriptors = new TFeature[takenDescriptorCount];
            var totalWeights = weights != null ? new double[takenDescriptorCount] : null;
            int[] instanceIndices = new int[takenDescriptorCount];

            // Flatten the per-instance descriptors into a single pool, remembering the
            // originating instance and propagating that instance's weight to every one
            // of its descriptors so that weights stay aligned with the descriptors.
            int c = 0;
            for (int i = 0; i < descriptorsPerInstance.Length; i++)
            {
                TFeature[] instanceDescriptors = descriptorsPerInstance[i];
                if (instanceDescriptors == null)
                    continue; // instance was skipped during extraction

                for (int j = 0; j < instanceDescriptors.Length; j++)
                {
                    totalDescriptors[c] = instanceDescriptors[j];
                    instanceIndices[c] = i;
                    if (totalWeights != null)
                        totalWeights[c] = weights[i];
                    c++;
                }
            }

            if (NumberOfDescriptors > 0)
            {
                // Draw indices over the whole pool and keep the first NumberOfDescriptors,
                // so that every pooled descriptor can be selected (not only the leading ones).
                // NOTE(review): assumes Vector.Sample(n) yields n indices within [0, n) - confirm.
                int[] idx = Vector.Sample(takenDescriptorCount).Take(NumberOfDescriptors).ToArray();
                totalDescriptors = totalDescriptors.Get(idx);
                instanceIndices = instanceIndices.Get(idx);

                // Weights must follow the exact same selection as the descriptors
                if (totalWeights != null)
                    totalWeights = totalWeights.Get(idx);
            }

            int[] hist = instanceIndices.Histogram();

            Debug.Assert(hist.Sum() == (NumberOfDescriptors > 0 ? NumberOfDescriptors : takenDescriptorCount));

            this.Statistics = new BagOfWordsStatistics()
            {
                TotalNumberOfInstances = x.Length,
                TotalNumberOfDescriptors = (int)totalDescriptorCounts.Sum(),
                TotalNumberOfDescriptorsPerInstance = NormalDistribution.Estimate(totalDescriptorCounts, new NormalOptions { Robust = true }),
                TotalNumberOfDescriptorsPerInstanceRange = new IntRange((int)totalDescriptorCounts.Min(), (int)totalDescriptorCounts.Max()),

                // NOTE(review): this is the number of histogram bins, which presumably
                // includes bins for instances that contributed no descriptors - confirm.
                NumberOfInstancesTaken = hist.Length,
                NumberOfDescriptorsTaken = totalDescriptors.Length,
                NumberOfDescriptorsTakenPerInstance = NormalDistribution.Estimate(hist.ToDouble(), new NormalOptions { Robust = true }),
                NumberOfDescriptorsTakenPerInstanceRange = new IntRange(hist.Min(), hist.Max())
            };

            return learn(totalDescriptors, totalWeights);
        }

        // Runs the clustering algorithm on the given descriptors (and optional weights),
        // keeps the resulting classifier and updates the number of words from its
        // number of classes. Returns this instance cast to the model type.
        private TModel learn(TFeature[] x, double[] weights)
        {
            var learned = Clustering.Learn(x, weights);

            classifier = learned;
            NumberOfWords = classifier.NumberOfClasses;

            return (TModel)this;
        }
        #endregion

        // Explicit ITransform implementation: reading forwards to the public
        // NumberOfInputs property; writing is rejected because the value is read-only.
        int ITransform.NumberOfInputs
        {
            get { return NumberOfInputs; }
            set { throw new InvalidOperationException("This property is read-only."); }
        }

        // Explicit ITransform implementation: reading forwards to the public
        // NumberOfOutputs property; writing is rejected because the value is read-only.
        int ITransform.NumberOfOutputs
        {
            get { return NumberOfOutputs; }
            set { throw new InvalidOperationException("This property is read-only."); }
        }
    }
}
